package testkey

import (
	"bytes"
	"crypto"
	"crypto/x509"
	"encoding/pem"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"runtime"
	"sync"

	"github.com/spiffe/spire/pkg/common/pemutil"
)

// packageDir is the directory that holds this package's source files, as
// computed by initPackageDir. bucket.path joins key file names onto it.
//
// It is initialized at its declaration rather than in an init function so
// that Go's package initialization ordering guarantees it is set before any
// other package-level initializer that depends on it runs.
var packageDir = initPackageDir()

// initPackageDir reports the directory of the source file this function is
// defined in, using the file name that runtime.Caller(0) records for the
// current frame. It panics if the runtime cannot supply caller information.
func initPackageDir() string {
	if _, thisFile, _, found := runtime.Caller(0); found {
		return filepath.Dir(thisFile)
	}
	panic("unable to obtain caller information")
}

// keyType describes one kind of key that a bucket can hold. K is the concrete
// signer type produced for that kind.
type keyType[K crypto.Signer] interface {
	// Path returns the name of the file backing the bucket. The bucket joins
	// it onto packageDir to get the location it loads from and saves to.
	Path() string
	// GenerateKey creates a new key of type K. The bucket calls it when asked
	// for the key one past the end of those it already holds.
	GenerateKey() (K, error)
}

// bucket is an ordered, file-backed collection of keys of a single keyType.
// Keys are read from the backing file on first use and the whole set is
// rewritten whenever a new key is generated.
type bucket[KT keyType[K], K crypto.Signer] struct {
	// kt supplies the backing file path and generates new keys.
	kt KT

	// mtx is held by At for the duration of each call, guarding keys.
	mtx  sync.Mutex
	// keys holds the keys in index order. A nil slice means nothing has been
	// loaded from the backing file yet.
	keys []K
}

// At returns the key at zero-based index n, loading the bucket from its
// backing file on first use.
//
// Keys are handed out sequentially:
//   - n below the number of held keys returns the existing key;
//   - n equal to the number of held keys generates a new key, appends it and
//     saves the whole set to the backing file;
//   - n negative or beyond that is an error.
//
// If saving a newly generated key fails, the error is returned together with
// the generated key, which remains held in memory.
func (b *bucket[KT, K]) At(n int) (key K, err error) {
	// A negative index would otherwise match the "n < len(b.keys)" case below
	// and panic on the slice access; report it as an error instead.
	if n < 0 {
		return key, errors.New("cannot ask for key at a negative index")
	}

	b.mtx.Lock()
	defer b.mtx.Unlock()

	if err := b.load(); err != nil {
		return key, err
	}

	switch {
	case n > len(b.keys):
		return key, errors.New("cannot ask for key beyond the end")
	case n < len(b.keys):
		return b.keys[n], nil
	default:
		// n == len(b.keys): extend the bucket by exactly one key.
		key, err = b.kt.GenerateKey()
		if err != nil {
			return key, err
		}
		b.keys = append(b.keys, key)
		if err := b.save(); err != nil {
			return key, err
		}
		return key, nil
	}
}

// load fills b.keys from the backing file unless keys are already present.
//
// A missing file is not an error: the bucket is simply left unloaded. Any
// other read failure is returned as-is, and so is a block whose decoded
// object is not of type K. b.keys is only assigned once every block has been
// accepted.
func (b *bucket[KT, K]) load() (err error) {
	if b.keys != nil {
		// Already loaded (or already extended in memory).
		return nil
	}

	blocks, err := pemutil.LoadBlocks(b.path())
	switch {
	case err == nil:
		// Fall through to decoding below.
	case errors.Is(err, os.ErrNotExist):
		return nil
	default:
		return err
	}

	loaded := make([]K, len(blocks))
	for i := range blocks {
		signer, isK := blocks[i].Object.(K)
		if !isK {
			return fmt.Errorf("expected %T; got %T", signer, blocks[i].Object)
		}
		loaded[i] = signer
	}

	b.keys = loaded
	return nil
}

// save rewrites the backing file with every key currently held in b.keys.
//
// The file starts with a "generated file" notice, followed by one
// "PRIVATE KEY" PEM block per key containing its PKCS #8 encoding. The whole
// content is assembled in memory first, so nothing is written if any key
// fails to marshal or encode. The file is written with mode 0600.
func (b *bucket[KT, K]) save() error {
	var buf bytes.Buffer
	buf.WriteString("// THIS FILE IS GENERATED. DO NOT EDIT THIS FILE DIRECTLY.\n\n")
	for _, key := range b.keys {
		keyBytes, err := x509.MarshalPKCS8PrivateKey(key)
		if err != nil {
			return err
		}
		// Propagate encoding failures rather than discarding them, so an
		// incomplete buffer can never reach the key file.
		if err := pem.Encode(&buf, &pem.Block{
			Type:  "PRIVATE KEY",
			Bytes: keyBytes,
		}); err != nil {
			return err
		}
	}
	return os.WriteFile(b.path(), buf.Bytes(), 0600)
}

// path returns the location of the bucket's backing file: the key type's
// Path joined onto packageDir.
func (b *bucket[KT, K]) path() string {
	name := b.kt.Path()
	return filepath.Join(packageDir, name)
}
